Switch CUBLAS to device-side pointer mode #2616

Open · kshyatt wants to merge 18 commits into master from ksh/device_side

Conversation

@kshyatt kshyatt (Contributor) commented Jan 10, 2025

Attempting to address #2571

I've set the pointer mode to "device side" during handle creation. Since gemmGroupedBatched doesn't support device-side pointer mode, it won't be usable. One workaround would be to add a new function that creates a handle with host-side mode, or to add the pointer mode as an optional kwarg to handle(). Very open to feedback on this.

I've set this up so that users can supply CuRefs of the appropriate result type to the level-1 functions for their results. If that's not provided, the functions execute as they do today (synchronously). Similarly, for functions taking alpha or beta scalar arguments, if the user provides a CuRef (actually a CuRefArray), the functions execute asynchronously and return immediately. If the user provides a Number, the behaviour is unchanged from today. I'm not married to this design and it can certainly be changed.
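
A rough sketch of the two call styles described above (editor's illustration, not code from this PR; exact method signatures may differ):

using CUDA, CUDA.CUBLAS

x = CUDA.rand(Float32, 1024)
y = CUDA.rand(Float32, 1024)

# Plain call without a result ref: behaves as today, the wrapper synchronizes
# and hands back a host-side value.
r_sync = CUBLAS.dot(length(x), x, y)

# Device-side result: pass a CuRef of the result type; the call is queued
# asynchronously and returns immediately.
result = CuRef{Float32}(0f0)
CUBLAS.dot(length(x), x, y, result)
r_async = result[]   # reading the ref synchronizes and fetches the value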

cc @Jutho

@kshyatt kshyatt requested a review from maleadt January 10, 2025 21:03
@kshyatt kshyatt added the cuda libraries label Jan 10, 2025
@kshyatt kshyatt (Contributor Author) commented Jan 10, 2025

I can also add some more @eval blocks to try to cut down on the repetitive fallback logic.
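
For instance, a hypothetical @eval loop generating the Number fallbacks (editor's sketch of the pattern, as if inside lib/cublas/wrappers.jl; function names and element types are illustrative):

using CUDA, CUDA.CUBLAS
import CUDA.CUBLAS: scal!, axpy!

for (fn, elty) in ((:scal!, :Float64), (:axpy!, :Float64))
    @eval function $fn(n::Integer, alpha::Number, args...)
        # wrap the host scalar in a device-side ref, call the async method,
        # then synchronize to preserve today's blocking behaviour
        gpu_α = CuRef{$elty}(alpha)
        ret = $fn(n, gpu_α, args...)
        synchronize()
        return ret
    end
end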

@kshyatt kshyatt (Contributor Author) commented Jan 10, 2025

Sample speedup:

julia> using CUDA, CUDA.CUBLAS, LinearAlgebra;

julia> n = Int(2^26);

julia> X = CUDA.rand(Float64, n);

julia> res = CuRef{Float64}(0.0);

# do some precompilation runs first

julia> @time CUBLAS.nrm2(n, X, res);
  0.000104 seconds (18 allocations: 288 bytes)

julia> @time CUBLAS.nrm2(n, X);
  0.001564 seconds (73 allocations: 3.094 KiB)

@github-actions github-actions bot (Contributor) left a comment

CUDA.jl Benchmarks

Benchmark suite Current: 252cfe6 Previous: 4d85f27 Ratio
latency/precompile 46530811171 ns 46476366460.5 ns 1.00
latency/ttfp 7004355347 ns 7027130025 ns 1.00
latency/import 3652355669 ns 3655499138 ns 1.00
integration/volumerhs 9626295.5 ns 9611290.5 ns 1.00
integration/byval/slices=1 146966 ns 147002 ns 1.00
integration/byval/slices=3 425312 ns 425499 ns 1.00
integration/byval/reference 144916 ns 145017 ns 1.00
integration/byval/slices=2 286014 ns 286202 ns 1.00
integration/cudadevrt 103349 ns 103544 ns 1.00
kernel/indexing 14256 ns 14154 ns 1.01
kernel/indexing_checked 14836 ns 14891 ns 1.00
kernel/occupancy 695.0987261146497 ns 635.3491124260355 ns 1.09
kernel/launch 2071.7 ns 2092.7 ns 0.99
kernel/rand 16944 ns 16799 ns 1.01
array/reverse/1d 19769 ns 19774 ns 1.00
array/reverse/2d 25287 ns 23706 ns 1.07
array/reverse/1d_inplace 11099 ns 10252 ns 1.08
array/reverse/2d_inplace 11463 ns 11923 ns 0.96
array/copy 20788 ns 21376 ns 0.97
array/iteration/findall/int 155824 ns 157809.5 ns 0.99
array/iteration/findall/bool 134846 ns 137133 ns 0.98
array/iteration/findfirst/int 153969 ns 154007.5 ns 1.00
array/iteration/findfirst/bool 153051 ns 154021 ns 0.99
array/iteration/scalar 61546.5 ns 61486 ns 1.00
array/iteration/logical 204856.5 ns 206205.5 ns 0.99
array/iteration/findmin/1d 38563 ns 39592 ns 0.97
array/iteration/findmin/2d 93728 ns 93947 ns 1.00
array/reductions/reduce/1d 37951 ns 38635 ns 0.98
array/reductions/reduce/2d 44823 ns 48115.5 ns 0.93
array/reductions/mapreduce/1d 36616 ns 36953 ns 0.99
array/reductions/mapreduce/2d 50826.5 ns 51316.5 ns 0.99
array/broadcast 20979 ns 20996 ns 1.00
array/copyto!/gpu_to_gpu 11736 ns 11843 ns 0.99
array/copyto!/cpu_to_gpu 209833 ns 209712 ns 1.00
array/copyto!/gpu_to_cpu 243177 ns 242645 ns 1.00
array/accumulate/1d 108711 ns 109015 ns 1.00
array/accumulate/2d 80196 ns 80178 ns 1.00
array/construct 1278.75 ns 1235.35 ns 1.04
array/random/randn/Float32 44050.5 ns 44506 ns 0.99
array/random/randn!/Float32 26631 ns 26892 ns 0.99
array/random/rand!/Int64 27088 ns 27134 ns 1.00
array/random/rand!/Float32 8611.333333333334 ns 8789.5 ns 0.98
array/random/rand/Int64 29911 ns 29973 ns 1.00
array/random/rand/Float32 13155 ns 13171 ns 1.00
array/permutedims/4d 60971 ns 61738.5 ns 0.99
array/permutedims/2d 55535 ns 55614 ns 1.00
array/permutedims/3d 56405 ns 56453.5 ns 1.00
array/sorting/1d 2775868 ns 2766804.5 ns 1.00
array/sorting/by 3367069 ns 3355439 ns 1.00
array/sorting/2d 1084319 ns 1082411 ns 1.00
cuda/synchronization/stream/auto 1031.9 ns 1054.7 ns 0.98
cuda/synchronization/stream/nonblocking 6528.9 ns 6495 ns 1.01
cuda/synchronization/stream/blocking 844.9014084507043 ns 810.1145833333334 ns 1.04
cuda/synchronization/context/auto 1260.4 ns 1191.2 ns 1.06
cuda/synchronization/context/nonblocking 6770.2 ns 6788.8 ns 1.00
cuda/synchronization/context/blocking 938.2666666666667 ns 945.7926829268292 ns 0.99

This comment was automatically generated by workflow using github-action-benchmark.

lib/cublas/wrappers.jl: review thread (outdated, resolved)
@kshyatt kshyatt (Contributor Author) commented Jan 11, 2025 via email

@kshyatt kshyatt (Contributor Author) commented Jan 11, 2025

Is the test failure something I've done? It seems GPUArrays-related.

@kshyatt kshyatt force-pushed the ksh/device_side branch 2 times, most recently from a0829fa to 5d52d10 Compare January 16, 2025 16:05
@kshyatt kshyatt (Contributor Author) commented Jan 16, 2025

OK, I think this is ready for review!

@Jutho Jutho (Contributor) commented Jan 16, 2025

I am not qualified to review, but certainly interested in the outcome. Will the non-blocking methods only accept CuRef objects for the scalar input or output quantities, or also zero-dimensional arrays (i.e. CuArray{T,0})?

@kshyatt kshyatt (Contributor Author) commented Jan 16, 2025 via email

@kshyatt kshyatt (Contributor Author) commented Jan 16, 2025

You can create a CuRefArray{T}, where T is some element type, from a single-element CuVector. In fact, CuRef itself does this under the hood.
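
For example (editor's sketch; assuming CuRefArray follows Base.RefArray's array-plus-index constructor, which may differ from the actual definition):

using CUDA

v = CuVector{Float64}([0.0])        # single-element device vector
ref = CUDA.CuRefArray(v, 1)         # wrap it as a device-resident Ref{Float64}

# CuRef does essentially the same thing under the hood:
ref2 = CuRef{Float64}(0.0)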

@maleadt maleadt (Member) left a comment

I wonder if we should also improve CuRef to initialize its memory by calling fill instead of memcpy: when calling memcpy, the copy likely won't be truly asynchronous (that would require pinned memory). But if we call fill, which should be possible for most scalars, the argument is passed by value, and I think the call will complete asynchronously.
Something to investigate!
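
A small sketch of the two initialization strategies being weighed here (editor's illustration, not the actual CuRef internals):

using CUDA

val = 3.14
d = CuVector{Float64}(undef, 1)

# memcpy-style upload from an unpinned host array: the copy is effectively
# synchronous with respect to the host.
copyto!(d, [val])

# fill-style initialization: the scalar is passed by value to the fill call,
# so it can complete asynchronously on the stream.
fill!(d, val)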

lib/cublas/wrappers.jl: 3 review threads (outdated, resolved)
@maleadt maleadt (Member) commented Jan 17, 2025

Something to investigate!

#2625

github-actions[bot]

This comment was marked as off-topic.

github-actions bot (Contributor) commented Jan 20, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Suggested changes:
diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 43ddfeaea..a5eb81c92 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -115,7 +115,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
                                 (:cublasCscal_v2, :cublasCscal_v2_64, :ComplexF32))
     @eval begin
         function scal!(n::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        x::StridedCuVecOrDenseMat{$elty})
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, alpha, x, stride(x, 1))
@@ -147,7 +147,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
                                        (:cublasZdscal_v2, :cublasZdscal_v2_64, :Float64, :ComplexF64))
     @eval begin
         function scal!(n::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        x::StridedCuVecOrDenseMat{$celty})
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, alpha, x, stride(x, 1))
@@ -190,8 +190,8 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
     @eval begin
         function $jname(n::Integer,
                         x::StridedCuVecOrDenseMat{$elty},
-                        y::StridedCuVecOrDenseMat{$elty},
-                        result::Ref{$elty},
+                y::StridedCuVecOrDenseMat{$elty},
+                result::Ref{$elty},
             )
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), result)
@@ -339,7 +339,7 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
                                 (:cublasCaxpy_v2, :cublasCaxpy_v2_64, :ComplexF32))
     @eval begin
         function axpy!(n::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        dx::StridedCuVecOrDenseMat{$elty},
                        dy::StridedCuVecOrDenseMat{$elty})
             if CUBLAS.version() >= v"12.0"
@@ -400,9 +400,9 @@ for (fname, fname_64, elty, cty, sty) in (
         function rot!(n::Integer,
                       x::StridedCuVecOrDenseMat{$elty},
                       y::StridedCuVecOrDenseMat{$elty},
-                      c::Ref{$cty},
-                      s::Ref{$sty},
-                     )
+                c::Ref{$cty},
+                s::Ref{$sty},
+            )
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), c, s)
             else
@@ -473,9 +473,9 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
                                 (:cublasIcamax_v2, :cublasIcamax_v2_64, :ComplexF32))
     @eval begin
         function iamax(n::Integer,
-                       dx::StridedCuVecOrDenseMat{$elty},
-                       result::Ref{Ti},
-                      ) where {Ti <: Integer}
+                dx::StridedCuVecOrDenseMat{$elty},
+                result::Ref{Ti},
+            ) where {Ti <: Integer}
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, dx, stride(dx, 1), result)
             else
@@ -494,9 +494,9 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
                                 (:cublasIcamin_v2, :cublasIcamin_v2_64, :ComplexF32))
     @eval begin
         function iamin(n::Integer,
-                       dx::StridedCuVecOrDenseMat{$elty},
-                       result::Ref{Ti},
-                      ) where {Ti <: Integer}
+                dx::StridedCuVecOrDenseMat{$elty},
+                result::Ref{Ti},
+            ) where {Ti <: Integer}
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, dx, stride(dx, 1), result)
             else
@@ -530,10 +530,10 @@ for (fname, fname_64, elty) in ((:cublasDgemv_v2, :cublasDgemv_v2_64, :Float64),
                                 (:cublasCgemv_v2, :cublasCgemv_v2_64, :ComplexF32))
     @eval begin
         function gemv!(trans::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             # handle trans
             m,n = size(A)
@@ -562,7 +562,7 @@ end
 function gemv(trans::Char, alpha::Ref{T}, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
     return gemv!(trans, alpha, A, x, CuRef{T}(zero(T)), similar(x, size(A, (trans == 'N' ? 1 : 2))))
 end
-function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T
+function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
     gemv!(trans, alpha, A, x, zero(T), similar(x, size(A, (trans == 'N' ? 1 : 2))))
 end
 # should this be async?
@@ -580,12 +580,12 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
     )
     @eval begin
         function gemv_batched!(trans::Char,
-                               alpha::Ref{$eltyconst},
-                               A::Vector{<:StridedCuMatrix{$eltyin}},
-                               x::Vector{<:StridedCuVector{$eltyin}},
-                               beta::Ref{$eltyconst},
-                               y::Vector{<:StridedCuVector{$eltyout}}
-                              )
+                alpha::Ref{$eltyconst},
+                A::Vector{<:StridedCuMatrix{$eltyin}},
+                x::Vector{<:StridedCuVector{$eltyin}},
+                beta::Ref{$eltyconst},
+                y::Vector{<:StridedCuVector{$eltyout}}
+            )
             if length(A) != length(x) || length(A) != length(y)
                 throw(DimensionMismatch("Lengths of inputs must be the same"))
             end
@@ -616,13 +616,13 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
             y
         end
         function gemv_batched!(
-                               trans::Char,
-                               alpha::Number,
-                               A::Vector{<:StridedCuMatrix{$eltyin}},
-                               x::Vector{<:StridedCuVector{$eltyin}},
-                               beta::Number,
-                               y::Vector{<:StridedCuVector{$eltyout}}
-                              )
+                trans::Char,
+                alpha::Number,
+                A::Vector{<:StridedCuMatrix{$eltyin}},
+                x::Vector{<:StridedCuVector{$eltyin}},
+                beta::Number,
+                y::Vector{<:StridedCuVector{$eltyout}}
+            )
             gpu_α = CuRef{$eltyconst}(alpha)
             gpu_β = CuRef{$eltyconst}(beta)
             y = gemv_batched!(trans, gpu_α, A, x, gpu_β, y)
@@ -642,12 +642,12 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
     )
     @eval begin
         function gemv_strided_batched!(trans::Char,
-                                       alpha::Ref{$eltyconst},
-                                       A::AbstractArray{$eltyin, 3},
-                                       x::AbstractArray{$eltyin, 2},
-                                       beta::Ref{$eltyconst},
-                                       y::AbstractArray{$eltyout, 2}
-                                      )
+                alpha::Ref{$eltyconst},
+                A::AbstractArray{$eltyin, 3},
+                x::AbstractArray{$eltyin, 2},
+                beta::Ref{$eltyconst},
+                y::AbstractArray{$eltyout, 2}
+            )
             if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2)
                 throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
             end
@@ -672,13 +672,13 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
             y
         end
         function gemv_strided_batched!(
-                                       trans::Char,
-                                       alpha::Number,
-                                       A::AbstractArray{$eltyin, 3},
-                                       x::AbstractArray{$eltyin, 2},
-                                       beta::Number,
-                                       y::AbstractArray{$eltyout, 2}
-                                      )
+                trans::Char,
+                alpha::Number,
+                A::AbstractArray{$eltyin, 3},
+                x::AbstractArray{$eltyin, 2},
+                beta::Number,
+                y::AbstractArray{$eltyout, 2}
+            )
             gpu_α = CuRef{$eltyconst}(alpha)
             gpu_β = CuRef{$eltyconst}(beta)
             y = gemv_strided_batched!(trans, gpu_α, A, x, gpu_β, y)
@@ -698,10 +698,10 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
                        m::Integer,
                        kl::Integer,
                        ku::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             n = size(A,2)
             # check dimensions
@@ -717,16 +717,17 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
             end
             y
         end
-        function gbmv!(trans::Char,
-                       m::Integer,
-                       kl::Integer,
-                       ku::Integer,
-                       alpha::Number,
-                       A::StridedCuMatrix{$elty},
-                       x::StridedCuVector{$elty},
-                       beta::Number,
-                       y::StridedCuVector{$elty}
-                      )
+        function gbmv!(
+                trans::Char,
+                m::Integer,
+                kl::Integer,
+                ku::Integer,
+                alpha::Number,
+                A::StridedCuMatrix{$elty},
+                x::StridedCuVector{$elty},
+                beta::Number,
+                y::StridedCuVector{$elty}
+            )
 
             gpu_α = CuRef{$elty}(alpha)
             gpu_β = CuRef{$elty}(beta)
@@ -736,8 +737,10 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
         end
     end
 end
-function gbmv(trans::Char, m::Integer, kl::Integer, ku::Integer, alpha::Ref{T},
-              A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
+function gbmv(
+        trans::Char, m::Integer, kl::Integer, ku::Integer, alpha::Ref{T},
+        A::StridedCuMatrix{T}, x::StridedCuVector{T}
+    ) where {T}
     # TODO: fix gbmv bug in julia
     n = size(A, 2)
     leny = trans == 'N' ? m : n
@@ -760,10 +763,10 @@ for (fname, fname_64, elty) in ((:cublasDspmv_v2, :cublasDspmv_v2_64, :Float64),
                                 (:cublasSspmv_v2, :cublasSspmv_v2_64, :Float32))
     @eval begin
         function spmv!(uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        AP::StridedCuVector{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             n = round(Int, (sqrt(8*length(AP))-1)/2)
             if n != length(x) || n != length(y) throw(DimensionMismatch("")) end
@@ -778,21 +781,24 @@ for (fname, fname_64, elty) in ((:cublasDspmv_v2, :cublasDspmv_v2_64, :Float64),
         end
     end
 end
-function spmv!(uplo::Char,
-               alpha::Number,
-               AP::StridedCuVector{T},
-               x::StridedCuVector{T},
-               beta::Number,
-               y::StridedCuVector{T}
-              ) where {T}
+function spmv!(
+        uplo::Char,
+        alpha::Number,
+        AP::StridedCuVector{T},
+        x::StridedCuVector{T},
+        beta::Number,
+        y::StridedCuVector{T}
+    ) where {T}
     gpu_α = CuRef{T}(alpha)
     gpu_β = CuRef{T}(beta)
     y = spmv!(uplo, gpu_α, AP, x, gpu_β, y)
     synchronize()
     return y
 end
-function spmv(uplo::Char, alpha::Ref{T},
-              AP::StridedCuVector{T}, x::StridedCuVector{T}) where {T}
+function spmv(
+        uplo::Char, alpha::Ref{T},
+        AP::StridedCuVector{T}, x::StridedCuVector{T}
+    ) where {T}
     return spmv!(uplo, alpha, AP, x, CuRef{T}(zero(T)), similar(x))
 end
 function spmv(uplo::Char, alpha::Number,
@@ -811,10 +817,10 @@ for (fname, fname_64, elty) in ((:cublasDsymv_v2, :cublasDsymv_v2_64, :Float64),
     # Note that the complex symv are not BLAS but auiliary functions in LAPACK
     @eval begin
         function symv!(uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             m, n = size(A)
             if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
@@ -865,10 +871,10 @@ for (fname, fname_64, elty) in ((:cublasZhemv_v2, :cublasZhemv_v2_64, :ComplexF6
                                 (:cublasChemv_v2, :cublasChemv_v2_64, :ComplexF32))
     @eval begin
         function hemv!(uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             # TODO: fix dimension check bug in julia
             m, n = size(A)
@@ -923,10 +929,10 @@ for (fname, fname_64, elty) in ((:cublasDsbmv_v2, :cublasDsbmv_v2_64, :Float64),
     @eval begin
         function sbmv!(uplo::Char,
                        k::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             m, n = size(A)
             #if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
@@ -982,10 +988,10 @@ for (fname, fname_64, elty) in ((:cublasZhbmv_v2, :cublasZhbmv_v2_64, :ComplexF6
     @eval begin
         function hbmv!(uplo::Char,
                        k::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             m, n = size(A)
             if !(1<=(1+k)<=n) throw(DimensionMismatch("Incorrect number of bands")) end
@@ -1169,7 +1175,7 @@ for (fname, fname_64, elty) in ((:cublasDger_v2, :cublasDger_v2_64, :Float64),
                                 (:cublasCgerc_v2, :cublasCgerc_v2_64, :ComplexF32))
     @eval begin
         function ger!(
-                      alpha::Ref{$elty},
+                alpha::Ref{$elty},
                       x::StridedCuVector{$elty},
                       y::StridedCuVector{$elty},
                       A::StridedCuMatrix{$elty})
@@ -1205,7 +1211,7 @@ for (fname, fname_64, elty) in ((:cublasDspr_v2, :cublasDspr_v2_64, :Float64),
                                 (:cublasSspr_v2, :cublasSspr_v2_64, :Float32))
     @eval begin
         function spr!(uplo::Char,
-                      alpha::Ref{$elty},
+                alpha::Ref{$elty},
                       x::StridedCuVector{$elty},
                       AP::StridedCuVector{$elty})
             n = round(Int, (sqrt(8*length(AP))-1)/2)
@@ -1239,7 +1245,7 @@ for (fname, fname_64, elty) in ((:cublasDsyr_v2, :cublasDsyr_v2_64, :Float64),
                                 (:cublasCsyr_v2, :cublasCsyr_v2_64, :ComplexF32))
     @eval begin
         function syr!(uplo::Char,
-                      alpha::Ref{$elty},
+                alpha::Ref{$elty},
                       x::StridedCuVector{$elty},
                       A::StridedCuMatrix{$elty})
             m, n = size(A)
@@ -1275,7 +1281,7 @@ for (fname, fname_64, elty, relty) in (
     )
     @eval begin
         function her!(uplo::Char,
-                      alpha::Ref{$relty},
+                alpha::Ref{$relty},
                       x::StridedCuVector{$elty},
                       A::StridedCuMatrix{$elty})
             m, n = size(A)
@@ -1309,11 +1315,11 @@ for (fname, fname_64, elty) in ((:cublasZher2_v2, :cublasZher2_v2_64, :ComplexF6
                                 (:cublasCher2_v2, :cublasCher2_v2_64, :ComplexF32))
     @eval begin
         function her2!(uplo::Char,
-                       alpha::Ref{$elty},
-                       x::StridedCuVector{$elty},
-                       y::StridedCuVector{$elty},
-                       A::StridedCuMatrix{$elty}
-                      )
+                alpha::Ref{$elty},
+                x::StridedCuVector{$elty},
+                y::StridedCuVector{$elty},
+                A::StridedCuMatrix{$elty}
+            )
             m, n = size(A)
             m == n || throw(DimensionMismatch("Matrix A is $m by $n but must be square"))
             length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
@@ -1353,10 +1359,10 @@ for (fname, fname_64, elty) in ((:cublasDgemm_v2, :cublasDgemm_v2_64, :Float64),
     @eval begin
         function gemm!(transA::Char,
                        transB::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuVecOrMat{$elty},
                        B::StridedCuVecOrMat{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuVecOrMat{$elty})
             m = size(A, transA == 'N' ? 1 : 2)
             k = size(A, transA == 'N' ? 2 : 1)
@@ -1494,10 +1500,10 @@ function gemmExComputeType(TA, TB, TC, m, k, n)
 end
 
 function gemmEx!(transA::Char, transB::Char,
-                 @nospecialize(alpha::Ref),
+        @nospecialize(alpha::Ref),
                  @nospecialize(A::StridedCuVecOrMat),
                  @nospecialize(B::StridedCuVecOrMat),
-                 @nospecialize(beta::Ref),
+        @nospecialize(beta::Ref),
                  @nospecialize(C::StridedCuVecOrMat);
                  algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
     m = size(A, transA == 'N' ? 1 : 2)
@@ -1552,10 +1558,10 @@ end
 
 # TODO for device mode pointers
 function gemmBatchedEx!(transA::Char, transB::Char,
-                 @nospecialize(alpha::Ref),
+        @nospecialize(alpha::Ref),
                  @nospecialize(A::Vector{<:StridedCuVecOrMat}),
                  @nospecialize(B::Vector{<:StridedCuVecOrMat}),
-                 @nospecialize(beta::Ref),
+        @nospecialize(beta::Ref),
                  @nospecialize(C::Vector{<:StridedCuVecOrMat});
                  algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
     if length(A) != length(B) || length(A) != length(C)
@@ -1623,11 +1629,11 @@ function gemmBatchedEx!(
 end
 
 function gemmStridedBatchedEx!(
-                 transA::Char, transB::Char,
-                 @nospecialize(alpha::Ref),
+        transA::Char, transB::Char,
+        @nospecialize(alpha::Ref),
                  @nospecialize(A::AbstractArray{Ta, 3}),
                  @nospecialize(B::AbstractArray{Tb, 3}),
-                 @nospecialize(beta::Ref),
+        @nospecialize(beta::Ref),
                  @nospecialize(C::AbstractArray{Tc, 3});
                  algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) where {Ta, Tb, Tc}
     if size(A, 3) != size(B, 3) || size(A, 3) != size(C, 3)
@@ -1866,10 +1872,10 @@ for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :F
     @eval begin
         function gemm_batched!(transA::Char,
                                transB::Char,
-                               alpha::Ref{$elty},
+                alpha::Ref{$elty},
                                A::Vector{<:StridedCuMatrix{$elty}},
                                B::Vector{<:StridedCuMatrix{$elty}},
-                               beta::Ref{$elty},
+                beta::Ref{$elty},
                                C::Vector{<:StridedCuMatrix{$elty}})
             if length(A) != length(B) || length(A) != length(C)
                 throw(DimensionMismatch(""))
@@ -1949,10 +1955,10 @@ for (fname, fname_64, elty) in ((:cublasDgemmStridedBatched, :cublasDgemmStrided
     @eval begin
         function gemm_strided_batched!(transA::Char,
                                transB::Char,
-                               alpha::Ref{$elty},
+                alpha::Ref{$elty},
                                A::AbstractArray{$elty, 3}, # allow PermutedDimsArray
                                B::AbstractArray{$elty, 3},
-                               beta::Ref{$elty},
+                beta::Ref{$elty},
                                C::AbstractArray{$elty, 3})
            m = size(A, transA == 'N' ? 1 : 2)
            k = size(A, transA == 'N' ? 2 : 1)
@@ -2032,10 +2038,10 @@ for (fname, fname_64, elty) in ((:cublasDsymm_v2, :cublasDsymm_v2_64, :Float64),
     @eval begin
         function symm!(side::Char,
                        uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             k, nA = size(A)
             if k != nA throw(DimensionMismatch("Matrix A must be square")) end
@@ -2094,9 +2100,9 @@ for (fname, fname_64, elty) in ((:cublasDsyrk_v2, :cublasDsyrk_v2_64, :Float64),
     @eval begin
         function syrk!(uplo::Char,
                        trans::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuVecOrMat{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             mC, n = size(C)
             if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2147,10 +2153,10 @@ for (fname, fname_64, elty) in ((:cublasDsyrkx, :cublasDsyrkx_64, :Float64),
     @eval begin
         function syrkx!(uplo::Char,
                        trans::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuVecOrMat{$elty},
                        B::StridedCuVecOrMat{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             mC, n = size(C)
             if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2206,10 +2212,10 @@ for (fname, fname_64, elty) in ((:cublasZhemm_v2, :cublasZhemm_v2_64, :ComplexF6
     @eval begin
         function hemm!(side::Char,
                        uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             mA, nA = size(A)
             m, n = size(B)
@@ -2269,9 +2275,9 @@ for (fname, fname_64, elty, relty) in (
     @eval begin
         function herk!(uplo::Char,
                        trans::Char,
-                       alpha::Ref{$relty},
+                alpha::Ref{$relty},
                        A::StridedCuVecOrMat{$elty},
-                       beta::Ref{$relty},
+                beta::Ref{$relty},
                        C::StridedCuMatrix{$elty})
             mC, n = size(C)
             if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2328,10 +2334,10 @@ for (fname, fname_64, elty) in ((:cublasDsyr2k_v2, :cublasDsyr2k_v2_64, :Float64
     @eval begin
         function syr2k!(uplo::Char,
                         trans::Char,
-                        alpha::Ref{$elty},
+                alpha::Ref{$elty},
                         A::StridedCuVecOrMat{$elty},
                         B::StridedCuVecOrMat{$elty},
-                        beta::Ref{$elty},
+                beta::Ref{$elty},
                         C::StridedCuMatrix{$elty})
             # TODO: check size of B in julia (syr2k!)
             m, n = size(C)
@@ -2387,7 +2393,7 @@ function syr2k(uplo::Char,
                B::StridedCuVecOrMat)
     T = eltype(A)
     n = size(A, trans == 'N' ? 1 : 2)
-    syr2k!(uplo, trans, convert(T, alpha), A, B, zero(T), similar(A, T, (n, n)))
+    return syr2k!(uplo, trans, convert(T, alpha), A, B, zero(T), similar(A, T, (n, n)))
 end
 function syr2k(uplo::Char, trans::Char, A::StridedCuVecOrMat, B::StridedCuVecOrMat)
     syr2k(uplo, trans, one(eltype(A)), A, B)
@@ -2401,10 +2407,10 @@ for (fname, fname_64, elty, relty) in (
     @eval begin
         function her2k!(uplo::Char,
                         trans::Char,
-                        alpha::Ref{$elty},
+                alpha::Ref{$elty},
                         A::StridedCuVecOrMat{$elty},
                         B::StridedCuVecOrMat{$elty},
-                        beta::Ref{$relty},
+                beta::Ref{$relty},
                         C::StridedCuMatrix{$elty})
             # TODO: check size of B in julia (her2k!)
             m, n = size(C)
@@ -2478,7 +2484,7 @@ for (mmname, smname, elty) in
                        uplo::Char,
                        transa::Char,
                        diag::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty},
                        C::StridedCuMatrix{$elty})
@@ -2500,7 +2506,7 @@ for (mmname, smname, elty) in
                        uplo::Char,
                        transa::Char,
                        diag::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty})
             m, n = size(B)
@@ -2565,7 +2571,7 @@ for (fname, fname_64, elty) in ((:cublasDtrsmBatched, :cublasDtrsmBatched_64, :F
                                uplo::Char,
                                transa::Char,
                                diag::Char,
-                               alpha::Ref{$elty},
+                alpha::Ref{$elty},
                                A::Vector{<:StridedCuMatrix{$elty}},
                                B::Vector{<:StridedCuMatrix{$elty}})
             if length(A) != length(B)
@@ -2621,9 +2627,9 @@ for (fname, fname_64, elty) in ((:cublasDgeam, :cublasDgeam_64, :Float64),
     @eval begin
         function geam!(transa::Char,
                        transb::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        B::StridedCuMatrix{$elty},
                        C::StridedCuMatrix{$elty})
             mA, nA = size(A)
@@ -2861,8 +2867,9 @@ for (fname, elty) in ((:cublasDgetriBatched, :Float64),
         end
 
         function getri_batched!(n, Aptrs::CuVector{CuPtr{$elty}},
-                                lda, Cptrs::CuVector{CuPtr{$elty}},ldc,
-                                pivotArray::CuArray{Cint})
+                lda, Cptrs::CuVector{CuPtr{$elty}}, ldc,
+                pivotArray::CuArray{Cint}
+            )
             batchSize = length(Aptrs)
             info = CuArray{Cint}(undef, batchSize)
             $fname(handle(), n, Aptrs, lda, pivotArray, Cptrs, ldc, info, batchSize)
diff --git a/test/libraries/cublas/level3.jl b/test/libraries/cublas/level3.jl
index 65b71c6a4..caf245a5b 100644
--- a/test/libraries/cublas/level3.jl
+++ b/test/libraries/cublas/level3.jl
@@ -352,12 +352,12 @@ k = 13
             @testset "herk!" begin
                 alpha = rand(elty)
                 beta = rand(elty)
-                A = rand(elty,m,m)
+                A = rand(elty, m, m)
                 hA = A + A'
                 d_A = CuArray(A)
                 d_C = CuArray(hA)
-                CUBLAS.herk!('U','N',real(alpha),d_A,real(beta),d_C)
-                C = real(alpha)*(A*A') + real(beta)*hA
+                CUBLAS.herk!('U', 'N', real(alpha), d_A, real(beta), d_C)
+                C = real(alpha) * (A * A') + real(beta) * hA
                 C = triu(C)
                 # move to host and compare
                 h_C = Array(d_C)
@@ -365,10 +365,10 @@ k = 13
                 @test C ≈ h_C
             end
             @testset "herk" begin
-                A = rand(elty,m,m)
+                A = rand(elty, m, m)
                 d_A = CuArray(A)
-                d_C = CUBLAS.herk('U','N',d_A)
-                C = A*A'
+                d_C = CUBLAS.herk('U', 'N', d_A)
+                C = A * A'
                 C = triu(C)
                 # move to host and compare
                 h_C = Array(d_C)
diff --git a/test/libraries/cublas/level3_gemm.jl b/test/libraries/cublas/level3_gemm.jl
index 6e04e8c42..6cccc2967 100644
--- a/test/libraries/cublas/level3_gemm.jl
+++ b/test/libraries/cublas/level3_gemm.jl
@@ -220,12 +220,12 @@ k = 13
             sB = sB + transpose(sB)
 
             for (TRa, ta, TRb, tb, TRc, a_func, b_func) in (
-                (UpperTriangular, identity,  LowerTriangular, identity,  Matrix, triu, tril),
-                (LowerTriangular, identity,  UpperTriangular, identity,  Matrix, tril, triu),
-                (UpperTriangular, identity,  UpperTriangular, transpose, Matrix, triu, triu),
-                (UpperTriangular, transpose, UpperTriangular, identity,  Matrix, triu, triu),
-                (LowerTriangular, identity,  LowerTriangular, transpose, Matrix, tril, tril),
-                (LowerTriangular, transpose, LowerTriangular, identity,  Matrix, tril, tril),
+                    (UpperTriangular, identity, LowerTriangular, identity, Matrix, triu, tril),
+                    (LowerTriangular, identity, UpperTriangular, identity, Matrix, tril, triu),
+                    (UpperTriangular, identity, UpperTriangular, transpose, Matrix, triu, triu),
+                    (UpperTriangular, transpose, UpperTriangular, identity, Matrix, triu, triu),
+                    (LowerTriangular, identity, LowerTriangular, transpose, Matrix, tril, tril),
+                    (LowerTriangular, transpose, LowerTriangular, identity, Matrix, tril, tril),
                 )
 
                 A = copy(sA) |> TRa

@maleadt maleadt (Member) commented Jan 20, 2025

CI failures seem relevant.

Feel free to ignore the formatter; I made it less spammy 😉

@kshyatt kshyatt (Contributor Author) commented Jan 23, 2025

I really do not know what is up with the 1.11 failure; it looks alloc_cache-related?

@maleadt maleadt (Member) commented Jan 25, 2025

Rebase to get rid of CI failures?

@kshyatt kshyatt (Contributor Author) commented Jan 25, 2025 via email

@kshyatt kshyatt force-pushed the ksh/device_side branch 2 times, most recently from 804a967 to bcd41c1 Compare January 25, 2025 21:47
@kshyatt kshyatt (Contributor Author) commented Jan 25, 2025

Gotta admit I'm a bit mystified here, as I cannot reproduce these trmm failures locally.

If I run only the libraries/cublas tests or even just libraries using the runtests.jl argument support, everything succeeds locally. If I run the full test suite, I start seeing intermittent illegal access errors/incorrect results in syr2k!. Weird!

lib/cublas/wrappers.jl: review thread (outdated, resolved)
@kshyatt kshyatt force-pushed the ksh/device_side branch 3 times, most recently from ec80895 to c3ad223 Compare February 4, 2025 22:59
codecov bot commented Feb 4, 2025

Codecov Report

Attention: Patch coverage is 95.48872% with 6 lines in your changes missing coverage. Please review.

Project coverage is 73.58%. Comparing base (4d85f27) to head (252cfe6).

Files with missing lines Patch % Lines
lib/cublas/wrappers.jl 95.12% 6 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2616      +/-   ##
==========================================
- Coverage   73.58%   73.58%   -0.01%     
==========================================
  Files         157      157              
  Lines       15242    15275      +33     
==========================================
+ Hits        11216    11240      +24     
- Misses       4026     4035       +9     


@kshyatt kshyatt (Contributor Author) commented Feb 4, 2025

Tests passed! Pushed a fix for the wrapper generator so that cublasXt functions use Ref rather than CuRef for their scalars: since the Xt methods are multi-GPU and cublasXt doesn't have this pointer-mode business, I think we should keep the CPU-side scalar for now.

@kshyatt kshyatt marked this pull request as ready for review February 4, 2025 23:53
@maleadt maleadt changed the title from "RFC: Use non-blocking device side pointer mode in CUBLAS, with fallbacks" to "Switch CUBLAS to device-side pointer mode" Feb 6, 2025
@maleadt maleadt (Member) commented Feb 6, 2025

Pushed some simplifications, e.g., supporting CuRef(x)[]. The new wrappers from #2642 also broke this PR, in part because there are still many PtrOrCuPtrs in there that are now illegal.

Finally, I don't feel like we should simply remove the group-batched functionality. Since the handle is task-local anyway, I'd rather we temporarily switch pointer modes for these API calls and keep the functionality working.
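
A possible shape for that (editor's sketch, assuming the low-level cublasSetPointerMode_v2 wrapper and pointer-mode enums are available; this is not the PR's actual code):

using CUDA, CUDA.CUBLAS

# temporarily flip the task-local handle to host-side pointer mode around a
# call that doesn't support device-side scalars, then restore device mode
function with_host_pointer_mode(f)
    h = CUBLAS.handle()
    CUBLAS.cublasSetPointerMode_v2(h, CUBLAS.CUBLAS_POINTER_MODE_HOST)
    try
        return f()
    finally
        CUBLAS.cublasSetPointerMode_v2(h, CUBLAS.CUBLAS_POINTER_MODE_DEVICE)
    end
end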

# so spawn a new worker when the test did so
if test in ["core/initialization", "core/cudadrv"]
    p = recycle_worker(p)
end

kshyatt (Contributor Author):

Ah cool, I completely missed that this function existed.

@kshyatt kshyatt (Contributor Author) commented Feb 6, 2025

Are we still "needs changes" on this or can it be squash-merged now that it passed CI?

Labels
cuda libraries: Stuff about CUDA library wrappers.
needs changes: Changes are needed.
performance: How fast can we go?
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants